{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MhoQ0WE77laV"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "_ckMIh7O7s6D"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jYysdyb-CaWM"
      },
      "source": [
        "# 使用 tf.distribute.Strategy 进行自定义训练"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S5Uhzt6vVIB2"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/distribute/custom_training\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" />在 TensorFlow.org 上查看</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/distribute/custom_training.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" />在 Google Colab 上运行</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/distribute/custom_training.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" />在 GitHub 上查看源代码</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/tutorials/distribute/custom_training.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" />下载该 notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FbVhjPpzn6BM"
      },
      "source": [
        "本教程演示了如何使用 [`tf.distribute.Strategy`](https://tensorflow.google.cn/guide/distribute_strategy) 来进行自定义训练循环。 我们将在流行的 MNIST 数据集上训练一个简单的 CNN 模型。 流行的 MNIST 数据集包含了 60000 张尺寸为 28 x 28 的训练图像和 10000 张尺寸为 28 x 28 的测试图像。 \n",
        "\n",
        "我们用自定义训练循环来训练我们的模型是因为它们在训练的过程中为我们提供了灵活性和在训练过程中更好的控制。而且，使它们调试模型和训练循环的时候更容易。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dzLKpmZICaWN"
      },
      "outputs": [],
      "source": [
        "# 导入 TensorFlow\n",
        "import tensorflow as tf\n",
        "\n",
        "# 帮助库\n",
        "import numpy as np\n",
        "import os\n",
        "\n",
        "print(tf.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MM6W__qraV55"
      },
      "source": [
        "## 下载流行的 MNIST 数据集"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7MqDQO0KCaWS"
      },
      "outputs": [],
      "source": [
        "fashion_mnist = tf.keras.datasets.fashion_mnist\n",
        "\n",
        "(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()\n",
        "\n",
        "# 向数组添加维度 -> 新的维度 == (28, 28, 1)\n",
        "# 我们这样做是因为我们模型中的第一层是卷积层\n",
        "# 而且它需要一个四维的输入 (批大小, 高, 宽, 通道).\n",
        "# 批大小维度稍后将添加。\n",
        "train_images = train_images[..., None]\n",
        "test_images = test_images[..., None]\n",
        "\n",
        "# 获取[0,1]范围内的图像。\n",
        "train_images = train_images / np.float32(255)\n",
        "test_images = test_images / np.float32(255)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4AXoHhrsbdF3"
      },
      "source": [
        "## 创建一个分发变量和图形的策略"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5mVuLZhbem8d"
      },
      "source": [
        "`tf.distribute.MirroredStrategy` 策略是如何运作的？\n",
        "\n",
        "*   所有变量和模型图都复制在副本上。\n",
        "*   输入都均匀分布在副本中。\n",
        "*   每个副本在收到输入后计算输入的损失和梯度。\n",
        "*   通过求和，每一个副本上的梯度都能同步。\n",
        "*   同步后，每个副本上的复制的变量都可以同样更新。\n",
        "\n",
        "注意：您可以将下面的所有代码放在一个单独单元内。 我们将它分成几个代码单元用于说明目的。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F2VeZUWUj5S4"
      },
      "outputs": [],
      "source": [
        "# 如果设备未在 `tf.distribute.MirroredStrategy` 的指定列表中，它会被自动检测到。\n",
        "strategy = tf.distribute.MirroredStrategy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZngeM_2o0_JO"
      },
      "outputs": [],
      "source": [
        "print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k53F5I_IiGyI"
      },
      "source": [
        "## 设置输入流水线"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Qb6nDgxiN_n"
      },
      "source": [
        "将图形和变量导出成平台不可识别的 SavedModel 格式。在你的模型保存后，你可以在有或没有范围的情况下载入它。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jwJtsCQhHK-E"
      },
      "outputs": [],
      "source": [
        "BUFFER_SIZE = len(train_images)\n",
        "\n",
        "BATCH_SIZE_PER_REPLICA = 64\n",
        "GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync\n",
        "\n",
        "EPOCHS = 10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J7fj3GskHC8g"
      },
      "source": [
        "创建数据集并分发它们："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WYrMNNDhAvVl"
      },
      "outputs": [],
      "source": [
        "train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) \n",
        "test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) \n",
        "\n",
        "train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)\n",
        "test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bAXAo_wWbWSb"
      },
      "source": [
        "## 创建模型\n",
        "\n",
        "使用 `tf.keras.Sequential` 创建一个模型。你也可以使用模型子类化 API 来完成这个。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9ODch-OFCaW4"
      },
      "outputs": [],
      "source": [
        "def create_model():\n",
        "  model = tf.keras.Sequential([\n",
        "      tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
        "      tf.keras.layers.MaxPooling2D(),\n",
        "      tf.keras.layers.Conv2D(64, 3, activation='relu'),\n",
        "      tf.keras.layers.MaxPooling2D(),\n",
        "      tf.keras.layers.Flatten(),\n",
        "      tf.keras.layers.Dense(64, activation='relu'),\n",
        "      tf.keras.layers.Dense(10, activation='softmax')\n",
        "    ])\n",
        "\n",
        "  return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9iagoTBfijUz"
      },
      "outputs": [],
      "source": [
        "# 创建检查点目录以存储检查点。\n",
        "checkpoint_dir = './training_checkpoints'\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e-wlFFZbP33n"
      },
      "source": [
        "## 定义损失函数\n",
        "\n",
        "通常，在一台只有一个 GPU / CPU 的机器上，损失需要除去输入批量中的示例数。\n",
        "\n",
        "*那么，使用 `tf.distribute.Strategy` 时应该如何计算损失？*\n",
        "\n",
        "* 举一个例子，假设您有4个 GPU，批量大小为64. 输入的一个批次分布在各个副本（4个 GPU）上，每个副本获得的输入大小为16。\n",
        "\n",
        "* 每个副本上的模型使用其各自的输入执行正向传递并计算损失。 现在，相较于将损耗除以其各自输入中的示例数（BATCH_SIZE_PER_REPLICA = 16），应将损失除以GLOBAL_BATCH_SIZE（64）。\n",
        "\n",
        "*为什么这样做？*\n",
        "\n",
        "* 需要这样做是因为在每个副本上计算梯度之后，它们通过 **summing** 来使得在自身在各个副本之间同步。\n",
        "\n",
        "*如何在TensorFlow中执行此操作？*\n",
        "* 如果您正在编写自定义训练循环，如本教程中所示，您应该将每个示例损失相加并将总和除以 GLOBAL_BATCH_SIZE ：\n",
        "`scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)` 或者你可以使用`tf.nn.compute_average_loss` 来获取每个示例的损失，可选的样本权重，将GLOBAL_BATCH_SIZE作为参数，并返回缩放的损失。\n",
        "\n",
        "* 如果您在模型中使用正则化损失，则需要进行缩放多个副本的损失。 您可以使用`tf.nn.scale_regularization_loss`函数执行此操作。\n",
        "\n",
        "* 建议不要使用`tf.reduce_mean`。 这样做会将损失除以实际的每个副本中每一步都会改变的批次大小。\n",
        "\n",
        "* 这种缩小和缩放是在 keras中 `modelcompile`和`model.fit`中自动完成的\n",
        "\n",
        "* 如果使用`tf.keras.losses`类（如下面这个例子所示），则需要将损失减少明确指定为“NONE”或者“SUM”。 使用 `tf.distribute.Strategy` 时，`AUTO`和`SUM_OVER_BATCH_SIZE` 是不能使用的。 不能使用 `AUTO` 是因为用户应明确考虑到在分布式情况下他们想做的哪些减少是正确的。不能使用`SUM_OVER_BATCH_SIZE`是因为目前它只按每个副本批次大小进行划分，并按照用户的副本数进行划分，这导致了它们很容易丢失。 因此，我们要求用户要明确这些减少。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R144Wci782ix"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  # 将减少设置为“无”，以便我们可以在之后进行这个减少并除以全局批量大小。\n",
        "  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
        "      reduction=tf.keras.losses.Reduction.NONE)\n",
        "  # 或者使用 loss_fn = tf.keras.losses.sparse_categorical_crossentropy\n",
        "  def compute_loss(labels, predictions):\n",
        "    per_example_loss = loss_object(labels, predictions)\n",
        "    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w8y54-o9T2Ni"
      },
      "source": [
        "## 定义衡量指标以跟踪损失和准确性\n",
        "\n",
        "这些指标可以跟踪测试的损失，训练和测试的准确性。 您可以使用`.result()`随时获取累积的统计信息。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zt3AHb46Tr3w"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
        "\n",
        "  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='train_accuracy')\n",
        "  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='test_accuracy')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iuKuNXPORfqJ"
      },
      "source": [
        "## 训练循环"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OrMmakq5EqeQ"
      },
      "outputs": [],
      "source": [
        "# 必须在`strategy.scope`下创建模型和优化器。\n",
        "with strategy.scope():\n",
        "  model = create_model()\n",
        "\n",
        "  optimizer = tf.keras.optimizers.Adam()\n",
        "\n",
        "  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3UX43wUu04EL"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  def train_step(inputs):\n",
        "    images, labels = inputs\n",
        "\n",
        "    with tf.GradientTape() as tape:\n",
        "      predictions = model(images, training=True)\n",
        "      loss = compute_loss(labels, predictions)\n",
        "\n",
        "    gradients = tape.gradient(loss, model.trainable_variables)\n",
        "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "\n",
        "    train_accuracy.update_state(labels, predictions)\n",
        "    return loss \n",
        "\n",
        "  def test_step(inputs):\n",
        "    images, labels = inputs\n",
        "\n",
        "    predictions = model(images, training=False)\n",
        "    t_loss = loss_object(labels, predictions)\n",
        "\n",
        "    test_loss.update_state(t_loss)\n",
        "    test_accuracy.update_state(labels, predictions)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gX975dMSNw0e"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  # `experimental_run_v2`将复制提供的计算并使用分布式输入运行它。\n",
        "  @tf.function\n",
        "  def distributed_train_step(dataset_inputs):\n",
        "    per_replica_losses = strategy.experimental_run_v2(train_step,\n",
        "                                                      args=(dataset_inputs,))\n",
        "    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,\n",
        "                           axis=None)\n",
        " \n",
        "  @tf.function\n",
        "  def distributed_test_step(dataset_inputs):\n",
        "    return strategy.experimental_run_v2(test_step, args=(dataset_inputs,))\n",
        "\n",
        "  for epoch in range(EPOCHS):\n",
        "    # 训练循环\n",
        "    total_loss = 0.0\n",
        "    num_batches = 0\n",
        "    for x in train_dist_dataset:\n",
        "      total_loss += distributed_train_step(x)\n",
        "      num_batches += 1\n",
        "    train_loss = total_loss / num_batches\n",
        "\n",
        "    # 测试循环\n",
        "    for x in test_dist_dataset:\n",
        "      distributed_test_step(x)\n",
        "\n",
        "    if epoch % 2 == 0:\n",
        "      checkpoint.save(checkpoint_prefix)\n",
        "\n",
        "    template = (\"Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, \"\n",
        "                \"Test Accuracy: {}\")\n",
        "    print (template.format(epoch+1, train_loss,\n",
        "                           train_accuracy.result()*100, test_loss.result(),\n",
        "                           test_accuracy.result()*100))\n",
        "\n",
        "    test_loss.reset_states()\n",
        "    train_accuracy.reset_states()\n",
        "    test_accuracy.reset_states()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z1YvXqOpwy08"
      },
      "source": [
        "以上示例中需要注意的事项：\n",
        "\n",
        "* 我们使用`for x in ...`迭代构造`train_dist_dataset`和`test_dist_dataset`。\n",
        "* 缩放损失是`distributed_train_step`的返回值。 这个值会在各个副本使用`tf.distribute.Strategy.reduce`的时候合并，然后通过`tf.distribute.Strategy.reduce`叠加各个返回值来跨批次。\n",
        "* 在执行`tf.distribute.Strategy.experimental_run_v2`时，`tf.keras.Metrics`应在`train_step`和`test_step`中更新。\n",
        "* `tf.distribute.Strategy.experimental_run_v2`返回策略中每个本地副本的结果，并且有多种方法可以处理此结果。 您可以执行`tf.distribute.Strategy.reduce`来获取汇总值。 您还可以执行`tf.distribute.Strategy.experimental_local_results`来获取每个本地副本中结果中包含的值列表。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-q5qp31IQD8t"
      },
      "source": [
        "## 恢复最新的检查点并进行测试"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WNW2P00bkMGJ"
      },
      "source": [
        "一个模型使用了`tf.distribute.Strategy`的检查点可以使用策略或者不使用策略进行恢复。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pg3B-Cw_cn3a"
      },
      "outputs": [],
      "source": [
        "eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='eval_accuracy')\n",
        "\n",
        "new_model = create_model()\n",
        "new_optimizer = tf.keras.optimizers.Adam()\n",
        "\n",
        "test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7qYii7KUYiSM"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def eval_step(images, labels):\n",
        "  predictions = new_model(images, training=False)\n",
        "  eval_accuracy(labels, predictions)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LeZ6eeWRoUNq"
      },
      "outputs": [],
      "source": [
        "checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)\n",
        "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))\n",
        "\n",
        "for images, labels in test_dataset:\n",
        "  eval_step(images, labels)\n",
        "\n",
        "print ('Accuracy after restoring the saved model without strategy: {}'.format(\n",
        "    eval_accuracy.result()*100))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EbcI87EEzhzg"
      },
      "source": [
        "## 迭代一个数据集的替代方法\n",
        "\n",
        "### 使用迭代器\n",
        "\n",
        "如果你想要迭代一个已经给定步骤数量而不需要整个遍历的数据集，你可以创建一个迭代器并在迭代器上调用`iter`和显式调用`next`。 您可以选择在 tf.function 内部和外部迭代数据集。 这是一个小片段，演示了使用迭代器在 tf.function 外部迭代数据集。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7c73wGC00CzN"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  for _ in range(EPOCHS):\n",
        "    total_loss = 0.0\n",
        "    num_batches = 0\n",
        "    train_iter = iter(train_dist_dataset)\n",
        "\n",
        "    for _ in range(10):\n",
        "      total_loss += distributed_train_step(next(train_iter))\n",
        "      num_batches += 1\n",
        "    average_train_loss = total_loss / num_batches\n",
        "\n",
        "    template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
        "    print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))\n",
        "    train_accuracy.reset_states()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GxVp48Oy0m6y"
      },
      "source": [
        "### 在 tf.function 中迭代\n",
        "\n",
        "您还可以使用`for x in ...`构造在 tf.function 内部迭代整个输入`train_dist_dataset`，或者像上面那样创建迭代器。下面的例子演示了在 tf.function 中包装一个 epoch 并在功能内迭代`train_dist_dataset`。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-REzmcXv00qm"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  @tf.function\n",
        "  def distributed_train_epoch(dataset):\n",
        "    total_loss = 0.0\n",
        "    num_batches = 0\n",
        "    for x in dataset:\n",
        "      per_replica_losses = strategy.experimental_run_v2(train_step,\n",
        "                                                        args=(x,))\n",
        "      total_loss += strategy.reduce(\n",
        "        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)\n",
        "      num_batches += 1\n",
        "    return total_loss / tf.cast(num_batches, dtype=tf.float32)\n",
        "\n",
        "  for epoch in range(EPOCHS):\n",
        "    train_loss = distributed_train_epoch(train_dist_dataset)\n",
        "\n",
        "    template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
        "    print (template.format(epoch+1, train_loss, train_accuracy.result()*100))\n",
        "\n",
        "    train_accuracy.reset_states()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MuZGXiyC7ABR"
      },
      "source": [
        "### 跟踪副本中的训练的损失\n",
        "\n",
        "注意：作为通用的规则，您应该使用`tf.keras.Metrics`来跟踪每个样本的值以避免它们在副本中合并。\n",
        "\n",
        "我们 *不* 建议使用`tf.metrics.Mean` 来跟踪不同副本的训练损失，因为在执行过程中会进行损失缩放计算。\n",
        "\n",
        "例如，如果您运行具有以下特点的训练作业：\n",
        "* 两个副本\n",
        "* 在每个副本上处理两个例子\n",
        "* 产生的损失值：每个副本为[2,3]和[4,5]\n",
        "* 全局批次大小 = 4\n",
        "\n",
        "通过损失缩放，您可以通过添加损失值来计算每个副本上的每个样本的损失值，然后除以全局批量大小。 在这种情况下：`（2 + 3）/ 4 = 1.25`和`（4 + 5）/ 4 = 2.25`。\n",
        "\n",
        "如果您使用 `tf.metrics.Mean` 来跟踪两个副本的损失，结果会有所不同。 在这个例子中，你最终得到一个`total`为 3.50 和`count`为 2 的结果，当调用`result（）`时，你将得到`total` /`count` = 1.75。 使用`tf.keras.Metrics`计算损失时会通过一个等于同步副本数量的额外因子来缩放。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xisYJaV9KZTN"
      },
      "source": [
        "### 例子和教程\n",
        "以下是一些使用自定义训练循环来分发策略的示例：\n",
        "\n",
        "1. [教程](training_loops.ipynb) 使用 `MirroredStrategy` 来训练 MNIST 。\n",
        "2. [DenseNet](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/densenet/distributed_train.py) 使用 `MirroredStrategy`的例子。\n",
        "1. [BERT](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) 使用 `MirroredStrategy` 和`TPUStrategy`来训练的例子。\n",
        "此示例对于了解如何在分发训练过程中如何载入一个检测点和定期生成检查点特别有帮助。\n",
        "2. [NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) 使用 `MirroredStrategy` 来启用 `keras_use_ctl` 标记。\n",
        "3. [NMT](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/nmt_with_attention/distributed_train.py) 使用 `MirroredStrategy`来训练的例子。\n",
        "\n",
        "更多的例子列在 [分发策略指南](../../guide/distribute_strategy.ipynb#examples_and_tutorials)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6hEJNsokjOKs"
      },
      "source": [
        "## 下一步\n",
        "\n",
        "在你的模型上尝试新的`tf.distribute.Strategy` API。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "custom_training.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
